#!/usr/bin/env python3
import numpy as np, pandas as pd
from astropy.io import fits
from astropy.cosmology import FlatLambdaCDM
from astropy.coordinates import SkyCoord
import astropy.units as u
import json, pathlib, sys

# Flat LambdaCDM cosmology used for every angular-diameter distance
# (astropy interprets a bare H0 number as km/s/Mpc).
COSMO = FlatLambdaCDM(H0=70, Om0=0.3)

# Ascending bin edges; each consecutive pair of values defines one bin.
RG_EDGES = [5.0, 7.5, 10.0, 12.5, 15.0]   # size proxy R_G, kpc
MS_EDGES = [10.2, 10.5, 10.8, 11.1]       # stellar mass, log10 M*

# Arcsec per pixel, used to convert pixel-unit size columns (OmegaCAM pix→arcsec).
# NOTE(review): KiDS co-added images are believed to be resampled to ~0.2"/pix
# rather than the native 0.214"/pix — confirm which grid the catalogue sizes use.
PIX_ARCSEC = 0.214

# ---------- utils ----------
def first_table_hdu(hdul):
    """Return the first HDU in *hdul* exposing a ``columns`` attribute (a table HDU).

    Raises:
        SystemExit: if no HDU in *hdul* has a ``columns`` attribute.
    """
    table = next((hdu for hdu in hdul if hasattr(hdu, "columns")), None)
    if table is None:
        raise SystemExit("No table HDU found in FITS.")
    return table

def make_map(names):
    """Build a case-insensitive lookup: lowercased name -> original name (as str).

    If two names collide after lowercasing, the later one wins.
    """
    lookup = {}
    for name in names:
        text = str(name)
        lookup[text.lower()] = text
    return lookup

def pick(name_map, candidates):
    """Return the original name of the first candidate found in *name_map*.

    *name_map* is a lowercase -> original mapping (see make_map); candidates
    are compared case-insensitively and tried in order. Returns None if none match.
    """
    lowered = (cand.lower() for cand in candidates)
    return next((name_map[key] for key in lowered if key in name_map), None)

def label(edges, i):
    """Return the text label for bin *i*: lower and upper edge joined by an en dash."""
    lo, hi = edges[i], edges[i + 1]
    return f"{lo}–{hi}"

def to_kpc(theta_arcsec, z):
    """Convert an angle in arcsec to a transverse size in kpc at redshift *z*.

    Uses size = theta[rad] * D_A(z) with D_A from COSMO. Returns NaN when
    either input is non-finite or z <= 0.
    """
    usable = np.isfinite(theta_arcsec) and np.isfinite(z) and z > 0
    if not usable:
        return np.nan
    da_kpc = COSMO.angular_diameter_distance(z).to_value("kpc")
    # pi/648000 converts arcsec to radians (648000 = 180 * 3600).
    return theta_arcsec * np.pi / 648000.0 * da_kpc

def size_arcsec_from(bright_row, bmap):
    """Robust size proxy in arcsec for one Bright-catalogue row.

    Tried in order, returning the first usable (finite, positive) value:
      1. geometric mean sqrt(a*b) of A_WORLD/B_WORLD (degrees -> arcsec);
      2. geometric mean of A_IMAGE/B_IMAGE (pixels -> arcsec via PIX_ARCSEC);
      3. FLUX_RADIUS-like column (pixels -> arcsec via PIX_ARCSEC).
    Returns NaN if nothing usable is found. *bmap* is a make_map() lookup of
    the row's column names.
    """
    def _usable(value):
        return np.isfinite(value) and value > 0

    axis_pairs = (
        (["A_WORLD", "AWORLD"], ["B_WORLD", "BWORLD"], 3600.0),
        (["A_IMAGE", "AIMAGE"], ["B_IMAGE", "BIMAGE"], PIX_ARCSEC),
    )
    for a_names, b_names, scale in axis_pairs:
        a_col = pick(bmap, a_names)
        b_col = pick(bmap, b_names)
        if not (a_col and b_col):
            continue
        a_arcsec = float(bright_row[a_col]) * scale
        b_arcsec = float(bright_row[b_col]) * scale
        if _usable(a_arcsec) and _usable(b_arcsec):
            return np.sqrt(a_arcsec * b_arcsec)

    r_col = pick(bmap, ["FLUX_RADIUS", "R_HALF", "RHALF", "R50"])
    if r_col:
        r_arcsec = float(bright_row[r_col]) * PIX_ARCSEC
        if _usable(r_arcsec):
            return r_arcsec
    return np.nan

def assign_bins(df):
    """Keep rows that fall inside both bin grids and label their bins.

    Bins are half-open [lo, hi) (np.digitize with right=False). Values below
    the first edge, at/above the last edge, or NaN are dropped. Returns a new
    DataFrame with added "R_G_bin" and "Mstar_bin" label columns.
    """
    def _bin_index(values, edges):
        idx = np.digitize(values, edges, right=False) - 1
        inside = (idx >= 0) & (idx < len(edges) - 1)
        return idx, inside

    rg_idx, rg_inside = _bin_index(df["R_G_kpc"].to_numpy(), RG_EDGES)
    ms_idx, ms_inside = _bin_index(df["Mstar_log10"].to_numpy(), MS_EDGES)
    keep = rg_inside & ms_inside

    binned = df.loc[keep].copy()
    binned["R_G_bin"] = [label(RG_EDGES, k) for k in rg_idx[keep]]
    binned["Mstar_bin"] = [label(MS_EDGES, k) for k in ms_idx[keep]]
    return binned

def join_by_id(B, L, Bn, Ln):
    """Look for an object-ID column in both the Bright (B) and LePhare (L) tables.

    Bn/Ln are the column names of B/L. Returns None when either table has no
    recognised ID column; otherwise ``(Bkey, Lkey, Bd, Ld)`` where Bkey/Lkey
    are the matched column names and Bd/Ld are one-column DataFrames
    ("lens_id") holding each table's ID values.
    """
    # Candidate ID column names (presumably those used by KiDS tables); matched
    # case-insensitively, first hit wins. Note Bkey and Lkey may differ.
    id_candidates = ["ID", "SeqNr", "SEQNR", "OBJID", "OBJECT_ID",
                     "SID", "SLID", "SOURCEID", "SOURCE_ID"]
    Bkey = pick(make_map(Bn), id_candidates)
    Lkey = pick(make_map(Ln), id_candidates)
    if not Bkey or not Lkey:
        return None
    Bd = pd.DataFrame({"lens_id": B.data[Bkey]})
    Ld = pd.DataFrame({"lens_id": L.data[Lkey]})
    return Bkey, Lkey, Bd, Ld

# ---------- main ----------
if __name__ == "__main__":
    import argparse

    # Candidate column names, tried in order (case-insensitively) by pick().
    RA_CANDIDATES = ["RAJ2000", "ALPHA_J2000", "RA", "ALPHAWIN_J2000", "RA_ICRS", "RA_DEG", "RADEG"]
    DEC_CANDIDATES = ["DECJ2000", "DELTA_J2000", "DEC", "DELTAWIN_J2000", "DE_ICRS", "DEC_DEG", "DEDEG"]
    Z_CANDIDATES = ["Z_B", "Z_BEST", "Z_PHOT", "PHOTOZ", "BPZ_Z_B", "Z"]
    MASS_CANDIDATES = ["MASS_BEST", "MASS_MED", "LOGMASS", "LOGMSTAR", "MSTAR", "MASS"]
    MATCH_RADIUS_ARCSEC = 0.5
    OUT_COLUMNS = ["lens_id", "ra_deg", "dec_deg", "z_lens", "R_G_kpc", "Mstar_log10", "R_G_bin", "Mstar_bin"]

    def _native(col):
        """Return a native-byte-order copy of a FITS column.

        FITS data are big-endian and pandas hash-based operations (merge) may
        reject non-native byte order. astype() always copies, which also
        detaches the values from the memory-mapped file.
        """
        arr = np.asarray(col)
        return arr.astype(arr.dtype.newbyteorder("="))

    ap = argparse.ArgumentParser(description="Build KiDS lenses (Bright+LePhare) → data/lenses.csv")
    ap.add_argument("--bright", required=True)
    ap.add_argument("--lephare", required=True)
    ap.add_argument("--out", default="data/lenses.csv")
    ap.add_argument("--max-rows", type=int, default=None)
    ap.add_argument("--coord-join", action="store_true", help="force RA/DEC nearest-neighbour join if no common ID")
    args = ap.parse_args()

    # Read everything from the (memory-mapped) files while they are open;
    # `with` guarantees they are closed, including on sys.exit().
    with fits.open(args.bright, memmap=True) as Hb, fits.open(args.lephare, memmap=True) as Hl:
        B = first_table_hdu(Hb)
        L = first_table_hdu(Hl)
        Bn, Ln = list(B.columns.names), list(L.columns.names)
        bmap, lmap = make_map(Bn), make_map(Ln)

        # RA/DEC/Z from Bright (case-insensitive)
        ra_c = pick(bmap, RA_CANDIDATES)
        de_c = pick(bmap, DEC_CANDIDATES)
        z_c = pick(bmap, Z_CANDIDATES)
        if not (ra_c and de_c and z_c):
            print("Bright columns not found (RA/DEC/Z). Here are the first 40 column names:", file=sys.stderr)
            print(Bn[:40], file=sys.stderr)
            sys.exit(1)

        # Mass from LePhare (log10 M*)
        m_c = pick(lmap, MASS_CANDIDATES)
        if not m_c:
            print("LePhare mass column not found. First 40 columns:", file=sys.stderr)
            print(Ln[:40], file=sys.stderr)
            sys.exit(1)

        ra_b = B.data[ra_c].astype(float)
        de_b = B.data[de_c].astype(float)
        z_b = B.data[z_c].astype(float)
        mass_l = L.data[m_c].astype(float)

        # Each lens row carries "_brow": the index of the Bright row it came
        # from, so size columns are later read from exactly that row (robust to
        # duplicate IDs, no re-matching needed).
        id_info = join_by_id(B, L, Bn, Ln)
        if id_info:
            Bkey, Lkey = id_info[0], id_info[1]
            Bd = pd.DataFrame({
                "lens_id": _native(B.data[Bkey]),
                "ra_deg": ra_b,
                "dec_deg": de_b,
                "z_lens": z_b,
                "_brow": np.arange(len(B.data)),
            })
            Ld = pd.DataFrame({
                "lens_id": _native(L.data[Lkey]),
                "Mstar_log10": mass_l,
            })
            D = Bd.merge(Ld, on="lens_id", how="inner")
        else:
            if not args.coord_join:
                print("No common ID key; re-run with --coord-join to match by sky position.", file=sys.stderr)
                sys.exit(1)
            # RA/DEC join within MATCH_RADIUS_ARCSEC
            print("No common ID; performing coordinate match within 0.5 arcsec...", file=sys.stderr)
            ra_lc = pick(lmap, RA_CANDIDATES)
            de_lc = pick(lmap, DEC_CANDIDATES)
            if not (ra_lc and de_lc):
                print("LePhare lacks RA/DEC for coord-join. Cannot proceed.", file=sys.stderr)
                sys.exit(1)

            cB = SkyCoord(ra=ra_b * u.deg, dec=de_b * u.deg)
            cL = SkyCoord(ra=L.data[ra_lc].astype(float) * u.deg,
                          dec=L.data[de_lc].astype(float) * u.deg)
            idx, sep2d, _ = cB.match_to_catalog_sky(cL)
            brow = np.flatnonzero(sep2d.arcsec <= MATCH_RADIUS_ARCSEC)
            D = pd.DataFrame({
                "lens_id": np.arange(brow.size, dtype=int),
                "ra_deg": ra_b[brow],
                "dec_deg": de_b[brow],
                "z_lens": z_b[brow],
                "Mstar_log10": mass_l[idx[brow]],
                "_brow": brow,
            })

        # Size proxy in kpc, read from each lens's own Bright row.
        sizes = [size_arcsec_from(B.data[int(row)], bmap) for row in D["_brow"].to_numpy()]
        D["R_G_kpc"] = [to_kpc(s, z) for s, z in zip(sizes, D["z_lens"].to_numpy())]

    # Clean, bin, limit
    D = D.replace([np.inf, -np.inf], np.nan).dropna(
        subset=["ra_deg", "dec_deg", "z_lens", "Mstar_log10", "R_G_kpc"])
    D = assign_bins(D)
    if args.max_rows:
        D = D.head(args.max_rows)

    # Output with exact headers expected downstream
    out = D[OUT_COLUMNS]
    out_path = pathlib.Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)  # default "data/" may not exist yet
    out.to_csv(out_path, index=False)

    edges_path = pathlib.Path("outputs") / "bin_edges.json"
    edges_path.parent.mkdir(exist_ok=True)
    with open(edges_path, "w") as fh:
        json.dump({"R_G_edges_kpc": RG_EDGES, "Mstar_edges": MS_EDGES}, fh)

    # Show per-bin counts for a quick QC
    print(f"Wrote {args.out} with {len(out)} rows.")
    if len(out):
        print(out.groupby(["R_G_bin", "Mstar_bin"]).size().to_string())
